{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "48de73c0",
   "metadata": {},
   "source": [
    "# 互相关运算"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "bc918d0a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "from d2l import torch as d2l"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "96703eb0",
   "metadata": {},
   "source": [
    "注意，输出大小略小于输入大小。这是因为卷积核的宽度和高度大于1，\n",
    "而卷积核只与图像中每个大小完全适合的位置进行互相关运算。\n",
    "所以，输出大小等于输入大小$n_h \\times n_w$减去卷积核大小$k_h \\times k_w$，即：\n",
    "\n",
    "$$(n_h-k_h+1) \\times (n_w-k_w+1).$$\n",
    "\n",
    "这是因为我们需要足够的空间在图像上“移动”卷积核。稍后，我们将看到如何通过在图像边界周围填充零来保证有足够的空间移动卷积核，从而保持输出大小不变。\n",
    "接下来，我们在`corr2d`函数中实现如上过程，该函数接受输入张量`X`和卷积核张量`K`，并返回输出张量`Y`。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "6db5776e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def corr2d(X, K):\n",
    "    '''计算二维互相关运算'''\n",
    "    h, w = K.shape # K是卷积核\n",
    "    Y = torch.zeros(X.shape[0] - h + 1, X.shape[1] - w + 1)\n",
    "    for i in range(Y.shape[0]):\n",
    "        for j in range(Y.shape[1]):\n",
    "            Y[i, j] = (X[i: i + h, j: j + w] * K).sum()\n",
    "    return Y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "57c0a606",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[19., 25.],\n",
       "        [37., 43.]])"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "X = torch.arange(9).reshape(3, 3)\n",
    "K = torch.arange(4).reshape(2, 2)\n",
    "corr2d(X, K)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "f93735f6",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Conv2D(nn.Module):\n",
    "    def __init__(self, kernel_size):\n",
    "        super().__init__()\n",
    "        self.weight = nn.Parameter(torch.rand(kernel_size))\n",
    "        self.bias = nn.Parameter(torch.zeros(1))\n",
    "        \n",
    "    def forward(self, X):\n",
    "        return corr2d(X, self.weight) + self.bias"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "19837250",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[1., 1., 0., 0., 0., 0., 1., 1.],\n",
       "        [1., 1., 0., 0., 0., 0., 1., 1.],\n",
       "        [1., 1., 0., 0., 0., 0., 1., 1.],\n",
       "        [1., 1., 0., 0., 0., 0., 1., 1.],\n",
       "        [1., 1., 0., 0., 0., 0., 1., 1.],\n",
       "        [1., 1., 0., 0., 0., 0., 1., 1.]])"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "X = torch.ones(6, 8)\n",
    "X[:, 2: 6] = 0\n",
    "X"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "4b6affc8",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 1., -1.]])"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "K = torch.tensor([[1.0, -1.0]])\n",
    "K"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "62b2b7df",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 0.,  1.,  0.,  0.,  0., -1.,  0.],\n",
       "        [ 0.,  1.,  0.,  0.,  0., -1.,  0.],\n",
       "        [ 0.,  1.,  0.,  0.,  0., -1.,  0.],\n",
       "        [ 0.,  1.,  0.,  0.,  0., -1.,  0.],\n",
       "        [ 0.,  1.,  0.,  0.,  0., -1.,  0.],\n",
       "        [ 0.,  1.,  0.,  0.,  0., -1.,  0.]])"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "Y = corr2d(X, K)\n",
    "Y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "b1691ce0",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0., 0., 0., 0., 0.],\n",
       "        [0., 0., 0., 0., 0.],\n",
       "        [0., 0., 0., 0., 0.],\n",
       "        [0., 0., 0., 0., 0.],\n",
       "        [0., 0., 0., 0., 0.],\n",
       "        [0., 0., 0., 0., 0.],\n",
       "        [0., 0., 0., 0., 0.],\n",
       "        [0., 0., 0., 0., 0.]])"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "corr2d(X.t(), K)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "f24a7e0f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "batch 2, loss  1.362\n",
      "batch 4, loss  0.261\n",
      "batch 6, loss  0.057\n",
      "batch 8, loss  0.015\n",
      "batch 10, loss  0.005\n"
     ]
    }
   ],
   "source": [
    "conv2d = nn.Conv2d(1, 1, kernel_size = (1, 2), bias=False)\n",
    "\n",
    "X = X.reshape(1, 1, 6, 8) \n",
    "Y = Y.reshape(1, 1, 6, 7)\n",
    "\n",
    "for i in range(10):\n",
    "    Y_hat = conv2d(X)\n",
    "    l = (Y_hat - Y)**2\n",
    "    conv2d.zero_grad()\n",
    "    l.sum().backward()\n",
    "    conv2d.weight.data[:] -= 3e-2 * conv2d.weight.grad\n",
    "    if (i + 1) % 2 == 0:\n",
    "        print(f'batch {i+1}, loss {l.sum(): .3f}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "7968e693",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 1.0006, -0.9879]])"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "conv2d.weight.data.reshape(1, 2)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python(d2l)",
   "language": "python",
   "name": "d2l"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.17"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
